{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Automatic Differentiation\n",
    ":label:`sec_autograd`\n",
    "\n",
    "As we have explained in :numref:`sec_calculus`,\n",
    "differentiation is a crucial step in nearly all deep learning optimization algorithms.\n",
    "While the calculations for taking these derivatives are straightforward,\n",
    "requiring only some basic calculus,\n",
    "for complex models, working out the updates by hand\n",
    "can be tedious and error-prone.\n",
    "\n",
    "Deep learning frameworks expedite this work\n",
    "by automatically calculating derivatives, i.e., *automatic differentiation*.\n",
    "In practice,\n",
    "based on the model we design,\n",
    "the system builds a *computational graph*,\n",
    "tracking which data were combined through\n",
    "which operations to produce the output.\n",
    "Automatic differentiation enables the system to subsequently backpropagate gradients.\n",
    "Here, *backpropagate* simply means to trace through the computational graph,\n",
    "filling in the partial derivatives with respect to each parameter.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/Functions.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "## A Simple Example\n",
    "\n",
    "As a toy example, say that we are interested\n",
    "in differentiating the function\n",
    "$y = 2\\mathbf{x}^{\\top}\\mathbf{x}$\n",
    "with respect to the column vector $\\mathbf{x}$.\n",
    "To start, let us create the variable `x` and assign it an initial value.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDArray x = manager.arange(4f);\n",
    "x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "Before we even calculate the gradient\n",
    "of $y$ with respect to $\\mathbf{x}$,\n",
    "we will need a place to store it.\n",
    "It is important that we do not allocate new memory\n",
    "every time we take a derivative with respect to a parameter\n",
    "because we will often update the same parameters\n",
    "thousands or millions of times\n",
    "and could quickly run out of memory.\n",
    "Note that the gradient of a scalar-valued function\n",
    "with respect to a vector $\\mathbf{x}$\n",
    "is itself vector-valued and has the same shape as $\\mathbf{x}$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// We allocate memory for an NDArray's gradient by invoking `setRequiresGradient(true)`\n",
    "x.setRequiresGradient(true);\n",
    "// After we calculate a gradient taken with respect to `x`, we will be able to\n",
    "// access it via the `getGradient` method; its values are initialized with 0s\n",
    "x.getGradient()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 16
   },
   "source": [
    "We place our code inside a _try-with-resources_ statement and declare the `GradientCollector` object that will build the computational graph.\n",
    "Now let us calculate $y$.\n",
    "\n",
    "Since `x` is a vector of length 4,\n",
    "an inner product of `x` and `x` is performed,\n",
    "yielding the scalar output that we assign to `y`.\n",
    "Next, we can automatically calculate the gradient of `y`\n",
    "with respect to each component of `x`\n",
    "by calling the function for backpropagation and printing the gradient."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n",
    "    NDArray y = x.dot(x).mul(2);\n",
    "    System.out.println(y);\n",
    "    gc.backward(y);\n",
    "}\n",
    "x.getGradient()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 20
   },
   "source": [
    "The gradient of the function $y = 2\\mathbf{x}^{\\top}\\mathbf{x}$\n",
    "with respect to $\\mathbf{x}$ should be $4\\mathbf{x}$.\n",
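    "\n",
    "As a quick check of this claim, writing the function componentwise\n",
    "(using only the definition of $y$ above) gives\n",
    "\n",
    "$$y = 2\\sum_{i} x_i^2, \\qquad \\frac{\\partial y}{\\partial x_j} = 4 x_j, \\qquad \\text{so} \\qquad \\nabla_{\\mathbf{x}}\\, y = 4\\mathbf{x}.$$\n",
    "\n",
    "For the value of `x` created above, $\\mathbf{x} = [0, 1, 2, 3]^{\\top}$, this gives $y = 28$ and a gradient of $[0, 4, 8, 12]^{\\top}$.\n",
    "\n",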
    "Let us quickly verify that our desired gradient was calculated correctly.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "x.getGradient().eq(x.mul(4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 24
   },
   "source": [
    "Now let us calculate another function of `x`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n",
    "    NDArray y = x.sum();\n",
    "    gc.backward(y);\n",
    "}\n",
    "x.getGradient() // Overwritten by the newly calculated gradient"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 28
   },
   "source": [
    "## Backward for Non-Scalar Variables\n",
    "\n",
    "Technically, when `y` is not a scalar,\n",
    "the most natural interpretation of the derivative of a vector `y`\n",
    "with respect to a vector `x` is a matrix.\n",
    "For higher-order and higher-dimensional `y` and `x`,\n",
    "the result of differentiation could be a high-order tensor.\n",
    "\n",
    "However, while these more exotic objects do show up\n",
    "in advanced machine learning (including in deep learning),\n",
    "more often when we are calling backward on a vector,\n",
    "we are trying to calculate the derivatives of the loss functions\n",
    "for each constituent of a *batch* of training examples.\n",
    "Here, our intent is not to calculate the full matrix of derivatives\n",
    "but rather the sum of the partial derivatives\n",
    "computed individually for each example in the batch.\n"
   ]
  },
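  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustrative worked case (assuming, as in the next cell, the elementwise product $\\mathbf{y} = \\mathbf{x} \\odot \\mathbf{x}$),\n",
    "summing the components of $\\mathbf{y}$ before differentiating gives\n",
    "\n",
    "$$\\frac{\\partial}{\\partial x_j} \\sum_{i} y_i = \\frac{\\partial}{\\partial x_j} \\sum_{i} x_i^2 = 2 x_j,$$\n",
    "\n",
    "so the stored gradient should be $2\\mathbf{x}$, which has the same shape as $\\mathbf{x}$.\n",
    "Equivalently, this is the vector of column sums of the matrix $\\mathbf{J} = \\mathrm{diag}(2\\mathbf{x})$ with entries $J_{ij} = \\partial y_i / \\partial x_j$, which never needs to be formed explicitly.\n"
   ]
  },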
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// When we invoke `backward` on a vector-valued variable `y` (function of `x`),\n",
    "// a new scalar variable is created by summing the elements in `y`. Then the\n",
    "// gradient of that scalar variable with respect to `x` is computed\n",
    "try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n",
    "    NDArray y = x.mul(x); // y is a vector\n",
    "    gc.backward(y);\n",
    "}\n",
    "x.getGradient() // Overwritten by the newly calculated gradient"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 32
   },
   "source": [
    "## Detaching Computation\n",
    "\n",
    "Sometimes, we wish to move some calculations\n",
    "outside of the recorded computational graph.\n",
    "For example, say that `y` was calculated as a function of `x`,\n",
    "and that subsequently `z` was calculated as a function of both `y` and `x`.\n",
    "Now, imagine that we wanted to calculate\n",
    "the gradient of `z` with respect to `x`,\n",
    "but wanted for some reason to treat `y` as a constant,\n",
    "and only take into account the role\n",
    "that `x` played after `y` was calculated.\n",
    "\n",
    "Here, we can detach `y` using `stopGradient` to return\n",
    "a new variable `u`\n",
    "that has the same value as `y` but discards any information\n",
    "about how `y` was computed in the computational graph.\n",
    "In other words, the gradient will not flow backwards through `u` to `x`.\n",
    "Thus, the following backpropagation function computes\n",
    "the partial derivative of `z = u * x` with respect to `x` while treating `u` as a constant,\n",
    "instead of the partial derivative of `z = x * x * x` with respect to `x`.\n"
   ]
  },
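  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the difference concrete, the two candidate derivatives can be written out elementwise\n",
    "(a sketch that assumes, as in the next cell, $\\mathbf{y} = \\mathbf{x} \\odot \\mathbf{x}$ and $\\mathbf{z} = \\mathbf{u} \\odot \\mathbf{x}$).\n",
    "With $\\mathbf{u}$ treated as a constant,\n",
    "\n",
    "$$z_i = u_i x_i \\quad \\Rightarrow \\quad \\frac{\\partial z_i}{\\partial x_i} = u_i = x_i^2,$$\n",
    "\n",
    "whereas without detaching,\n",
    "\n",
    "$$z_i = x_i^3 \\quad \\Rightarrow \\quad \\frac{\\partial z_i}{\\partial x_i} = 3 x_i^2.$$\n",
    "\n",
    "The gradient stored for `x` should therefore equal `u` in the detached case, not `3 * x * x`.\n"
   ]
  },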
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n",
    "    NDArray y = x.mul(x);\n",
    "    NDArray u = y.stopGradient();\n",
    "    NDArray z = u.mul(x);\n",
    "    gc.backward(z);\n",
    "    System.out.println(x.getGradient().eq(u));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 36
   },
   "source": [
    "We can subsequently invoke backpropagation on `y` to get the derivative of `y = x * x` with respect to `x`, which is `2 * x`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n",
    "    NDArray y = x.mul(x);\n",
    "    gc.backward(y);\n",
    "    System.out.println(x.getGradient().eq(x.mul(2)));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 40
   },
   "source": [
    "## Computing the Gradient of Java Control Flow\n",
    "\n",
    "One benefit of using automatic differentiation\n",
    "is that even if building the computational graph of a function\n",
    "required passing through a maze of Java control flow\n",
    "(e.g., conditionals, loops, and arbitrary function calls),\n",
    "we can still calculate the gradient of the resulting variable.\n",
    "In the following snippet, note that\n",
    "the number of iterations of the `while` loop\n",
    "and the evaluation of the `if` statement\n",
    "both depend on the value of the input `a`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public NDArray f(NDArray a) {\n",
    "    NDArray b = a.mul(2);\n",
    "    while (b.norm().getFloat() < 1000) {\n",
    "        b = b.mul(2);\n",
    "    }\n",
    "    NDArray c;\n",
    "    if (b.sum().getFloat() > 0) {\n",
    "        c = b;\n",
    "    } else {\n",
    "        c = b.mul(100);\n",
    "    }\n",
    "    return c;\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 44
   },
   "source": [
    "Let us compute the gradient.\n",
    "\n",
    "We can then analyze the `f` function defined above.\n",
    "Note that it is piecewise linear in its input `a`.\n",
    "In other words, for any `a` there exists some constant scalar `k`\n",
    "such that `f(a) = k * a`, where the value of `k` depends on the input `a`.\n",
    "Consequently, `d / a` recovers this constant `k`, so comparing it with the computed gradient lets us verify that the gradient is correct."
   ]
  },
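  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The constant can be spelled out for a scalar input, as a sketch based only on the definition of `f` above.\n",
    "Here $m$ is our own notation for the total number of doublings applied to `a`,\n",
    "and we assume $a \\neq 0$ (for $a = 0$ the `while` loop would never terminate):\n",
    "\n",
    "$$f(a) = k\\,a, \\qquad k = \\begin{cases} 2^{m} & \\text{if } a > 0, \\\\ 100 \\cdot 2^{m} & \\text{if } a < 0, \\end{cases} \\qquad \\frac{df}{da} = k = \\frac{f(a)}{a}.$$\n",
    "\n",
    "Because $m$ and the sign of $a$ stay fixed under small perturbations of $a$ (away from the points where the loop count changes),\n",
    "$k$ is locally constant, which is why the gradient should match `d / a`.\n"
   ]
  },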
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDArray a = manager.randomNormal(new Shape(1));\n",
    "a.setRequiresGradient(true);\n",
    "try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n",
    "    NDArray d = f(a);\n",
    "    gc.backward(d);\n",
    "\n",
    "    System.out.println(a.getGradient().eq(d.div(a)));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 52
   },
   "source": [
    "## Summary\n",
    "\n",
    "* Deep learning frameworks can automate the calculation of derivatives. To use this capability, we first attach gradients to the variables with respect to which we desire partial derivatives. We then record the computation of our target value, invoke backpropagation on it, and access the resulting gradient.\n",
    "\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Why is the second derivative much more expensive to compute than the first derivative?\n",
    "1. After running the function for backpropagation, immediately run it again and see what happens.\n",
    "1. In the control flow example where we calculate the derivative of `d` with respect to `a`, what would happen if we changed the variable `a` to a random vector or matrix? At this point, the result of the calculation `f(a)` is no longer a scalar. What happens to the result? How do we analyze this?\n",
    "1. Design a different example of computing the gradient through control flow. Run and analyze the result.\n",
    "1. Let $f(x) = \\sin(x)$. Plot $f(x)$ and $\\frac{df(x)}{dx}$, where the latter is computed without exploiting that $f'(x) = \\cos(x)$.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
